{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "3c5d72f4",
   "metadata": {},
   "source": [
    "# DistilBERT Classifier using Hugging Face `transformers`"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9acb26bf-d3ab-44a3-b3a6-a4c18513d392",
   "metadata": {},
   "source": [
    "![](figures/finetuning-ii.png)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "6fd9cda8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# pip install transformers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "92ea5612",
   "metadata": {},
   "outputs": [],
   "source": [
    "# pip install datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "033b75c5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch       : 1.12.1\n",
      "transformers: 4.23.1\n",
      "datasets    : 2.6.1\n",
      "\n",
      "conda environment: dl-fundamentals\n",
      "\n"
     ]
    }
   ],
   "source": [
    "%load_ext watermark\n",
    "%watermark --conda -p torch,transformers,datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "602ba8a0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cuda:0\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "print(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4cfd724d",
   "metadata": {},
   "source": [
    "# 1 Loading the Dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fd06d930",
   "metadata": {},
   "source": [
    "The IMDB movie review dataset consists of 50k movie reviews with sentiment label (0: negative, 1: positive)."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "60fe0b76",
   "metadata": {},
   "source": [
    "## 1a) Load from `datasets` Hub"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "447e24bb",
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import list_datasets, load_dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "2baf2f16",
   "metadata": {},
   "outputs": [],
   "source": [
    "# list_datasets()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "6310d5bf",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Found cached dataset imdb (/home/raschka/.cache/huggingface/datasets/imdb/plain_text/1.0.0/2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1)\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "298e3320a2444a62b242fe514f3d6ab1",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/3 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "DatasetDict({\n",
      "    train: Dataset({\n",
      "        features: ['text', 'label'],\n",
      "        num_rows: 25000\n",
      "    })\n",
      "    test: Dataset({\n",
      "        features: ['text', 'label'],\n",
      "        num_rows: 25000\n",
      "    })\n",
      "    unsupervised: Dataset({\n",
      "        features: ['text', 'label'],\n",
      "        num_rows: 50000\n",
      "    })\n",
      "})\n"
     ]
    }
   ],
   "source": [
    "imdb_data = load_dataset(\"imdb\")\n",
    "print(imdb_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "552bbb2e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'text': \"This film is terrible. You don't really need to read this review further. If you are planning on watching it, suffice to say - don't (unless you are studying how not to make a good movie).<br /><br />The acting is horrendous... serious amateur hour. Throughout the movie I thought that it was interesting that they found someone who speaks and looks like Michael Madsen, only to find out that it is actually him! A new low even for him!!<br /><br />The plot is terrible. People who claim that it is original or good have probably never seen a decent movie before. Even by the standard of Hollywood action flicks, this is a terrible movie.<br /><br />Don't watch it!!! Go for a jog instead - at least you won't feel like killing yourself.\",\n",
       " 'label': 0}"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "imdb_data[\"train\"][99]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "40bdb9c5",
   "metadata": {},
   "source": [
    "## 1b) Load from local directory"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9103ec2d",
   "metadata": {},
   "source": [
    "The IMDB movie review set can be downloaded from http://ai.stanford.edu/~amaas/data/sentiment/. After downloading the dataset, decompress the files.\n",
    "\n",
    "A) If you are working with Linux or MacOS X, open a new terminal window cd into the download directory and execute\n",
    "\n",
    "    tar -zxf aclImdb_v1.tar.gz\n",
    "\n",
    "B) If you are working with Windows, download an archiver such as 7Zip to extract the files from the download archive."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ac508bb8",
   "metadata": {},
   "source": [
    "C) Use the following code to download and unzip the dataset via Python"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "241ecc96",
   "metadata": {},
   "source": [
    "**Download the movie reviews**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "02aeade4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "100% | 80.23 MB | 10.78 MB/s | 7.44 sec elapsed"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import sys\n",
    "import tarfile\n",
    "import time\n",
    "import urllib.request\n",
    "\n",
    "source = \"http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz\"\n",
    "target = \"aclImdb_v1.tar.gz\"\n",
    "\n",
    "if os.path.exists(target):\n",
    "    os.remove(target)\n",
    "\n",
    "\n",
    "def reporthook(count, block_size, total_size):\n",
    "    global start_time\n",
    "    if count == 0:\n",
    "        start_time = time.time()\n",
    "        return\n",
    "    duration = time.time() - start_time\n",
    "    progress_size = int(count * block_size)\n",
    "    speed = progress_size / (1024.0**2 * duration)\n",
    "    percent = count * block_size * 100.0 / total_size\n",
    "\n",
    "    sys.stdout.write(\n",
    "        f\"\\r{int(percent)}% | {progress_size / (1024.**2):.2f} MB \"\n",
    "        f\"| {speed:.2f} MB/s | {duration:.2f} sec elapsed\"\n",
    "    )\n",
    "    sys.stdout.flush()\n",
    "\n",
    "\n",
    "if not os.path.isdir(\"aclImdb\") and not os.path.isfile(\"aclImdb_v1.tar.gz\"):\n",
    "    urllib.request.urlretrieve(source, target, reporthook)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "2a867dcc",
   "metadata": {},
   "outputs": [],
   "source": [
    "if not os.path.isdir(\"aclImdb\"):\n",
    "\n",
    "    with tarfile.open(target, \"r:gz\") as tar:\n",
    "        tar.extractall()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9318d4d0",
   "metadata": {},
   "source": [
    "**Convert them to a pandas DataFrame and save them as CSV**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "464e587c",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████| 50000/50000 [00:56<00:00, 887.11it/s]\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import sys\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from packaging import version\n",
    "from tqdm import tqdm\n",
    "\n",
    "# change the `basepath` to the directory of the\n",
    "# unzipped movie dataset\n",
    "\n",
    "basepath = \"aclImdb\"\n",
    "\n",
    "labels = {\"pos\": 1, \"neg\": 0}\n",
    "\n",
    "df = pd.DataFrame()\n",
    "\n",
    "with tqdm(total=50000) as pbar:\n",
    "    for s in (\"test\", \"train\"):\n",
    "        for l in (\"pos\", \"neg\"):\n",
    "            path = os.path.join(basepath, s, l)\n",
    "            for file in sorted(os.listdir(path)):\n",
    "                with open(os.path.join(path, file), \"r\", encoding=\"utf-8\") as infile:\n",
    "                    txt = infile.read()\n",
    "\n",
    "                if version.parse(pd.__version__) >= version.parse(\"1.3.2\"):\n",
    "                    x = pd.DataFrame(\n",
    "                        [[txt, labels[l]]], columns=[\"review\", \"sentiment\"]\n",
    "                    )\n",
    "                    df = pd.concat([df, x], ignore_index=False)\n",
    "\n",
    "                else:\n",
    "                    df = df.append([[txt, labels[l]]], ignore_index=True)\n",
    "                pbar.update()\n",
    "df.columns = [\"text\", \"label\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "02649593",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "np.random.seed(0)\n",
    "df = df.reindex(np.random.permutation(df.index))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "59ca0386",
   "metadata": {},
   "source": [
    "**Basic datasets analysis and sanity checks**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "c2db547a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Class distribution:\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "array([25000, 25000])"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "print(\"Class distribution:\")\n",
    "np.bincount(df[\"label\"].values)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "a007e612",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(4, 173.0, 2470)"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "text_len = df[\"text\"].apply(lambda x: len(x.split()))\n",
    "text_len.min(), text_len.median(), text_len.max() "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "00f4b04d",
   "metadata": {},
   "source": [
    "**Split data into training, validation, and test sets**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "ff703901",
   "metadata": {},
   "outputs": [],
   "source": [
    "df_shuffled = df.sample(frac=1, random_state=1).reset_index()\n",
    "\n",
    "df_train = df_shuffled.iloc[:35_000]\n",
    "df_val = df_shuffled.iloc[35_000:40_000]\n",
    "df_test = df_shuffled.iloc[40_000:]\n",
    "\n",
    "df_train.to_csv(\"train.csv\", index=False, encoding=\"utf-8\")\n",
    "df_val.to_csv(\"validation.csv\", index=False, encoding=\"utf-8\")\n",
    "df_test.to_csv(\"test.csv\", index=False, encoding=\"utf-8\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2bd5f770",
   "metadata": {},
   "source": [
    "**Load the dataset via `load_dataset`**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "a1aa66c7",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using custom data configuration default-4edd77922b6957c3\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Downloading and preparing dataset csv/default to /home/raschka/.cache/huggingface/datasets/csv/default-4edd77922b6957c3/0.0.0/6b34fb8fcf56f7c8ba51dc895bfa2bfbe43546f190a60fcf74bb5e8afdcc2317...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a7c594238c4348a19b2ea647480c311b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading data files:   0%|          | 0/3 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3abf4456eb5b49c59bbd3af8204dd681",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Extracting data files:   0%|          | 0/3 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "0 tables [00:00, ? tables/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/raschka/miniforge3/envs/dl-fundamentals/lib/python3.9/site-packages/datasets/download/streaming_download_manager.py:714: FutureWarning: the 'mangle_dupe_cols' keyword is deprecated and will be removed in a future version. Please take steps to stop the use of 'mangle_dupe_cols'\n",
      "  return pd.read_csv(xopen(filepath_or_buffer, \"rb\", use_auth_token=use_auth_token), **kwargs)\n",
      "/home/raschka/miniforge3/envs/dl-fundamentals/lib/python3.9/site-packages/pandas/io/common.py:122: ResourceWarning: unclosed file <_io.BufferedReader name='/home/raschka/scratch/deeplearning-models/pytorch_ipynb/transformer/train.csv'>\n",
      "  self.handle.detach()\n",
      "ResourceWarning: Enable tracemalloc to get the object allocation traceback\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "0 tables [00:00, ? tables/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/raschka/miniforge3/envs/dl-fundamentals/lib/python3.9/site-packages/datasets/download/streaming_download_manager.py:714: FutureWarning: the 'mangle_dupe_cols' keyword is deprecated and will be removed in a future version. Please take steps to stop the use of 'mangle_dupe_cols'\n",
      "  return pd.read_csv(xopen(filepath_or_buffer, \"rb\", use_auth_token=use_auth_token), **kwargs)\n",
      "/home/raschka/miniforge3/envs/dl-fundamentals/lib/python3.9/site-packages/pandas/io/common.py:122: ResourceWarning: unclosed file <_io.BufferedReader name='/home/raschka/scratch/deeplearning-models/pytorch_ipynb/transformer/validation.csv'>\n",
      "  self.handle.detach()\n",
      "ResourceWarning: Enable tracemalloc to get the object allocation traceback\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "0 tables [00:00, ? tables/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/raschka/miniforge3/envs/dl-fundamentals/lib/python3.9/site-packages/datasets/download/streaming_download_manager.py:714: FutureWarning: the 'mangle_dupe_cols' keyword is deprecated and will be removed in a future version. Please take steps to stop the use of 'mangle_dupe_cols'\n",
      "  return pd.read_csv(xopen(filepath_or_buffer, \"rb\", use_auth_token=use_auth_token), **kwargs)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Dataset csv downloaded and prepared to /home/raschka/.cache/huggingface/datasets/csv/default-4edd77922b6957c3/0.0.0/6b34fb8fcf56f7c8ba51dc895bfa2bfbe43546f190a60fcf74bb5e8afdcc2317. Subsequent calls will reuse this data.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/raschka/miniforge3/envs/dl-fundamentals/lib/python3.9/site-packages/pandas/io/common.py:122: ResourceWarning: unclosed file <_io.BufferedReader name='/home/raschka/scratch/deeplearning-models/pytorch_ipynb/transformer/test.csv'>\n",
      "  self.handle.detach()\n",
      "ResourceWarning: Enable tracemalloc to get the object allocation traceback\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6c4940fc5f334d5091e8940325171c74",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/3 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "DatasetDict({\n",
      "    train: Dataset({\n",
      "        features: ['index', 'text', 'label'],\n",
      "        num_rows: 35000\n",
      "    })\n",
      "    validation: Dataset({\n",
      "        features: ['index', 'text', 'label'],\n",
      "        num_rows: 5000\n",
      "    })\n",
      "    test: Dataset({\n",
      "        features: ['index', 'text', 'label'],\n",
      "        num_rows: 10000\n",
      "    })\n",
      "})\n"
     ]
    }
   ],
   "source": [
    "imdb_dataset = load_dataset(\n",
    "    \"csv\",\n",
    "    data_files={\n",
    "        \"train\": \"train.csv\",\n",
    "        \"validation\": \"validation.csv\",\n",
    "        \"test\": \"test.csv\",\n",
    "    },\n",
    ")\n",
    "\n",
    "print(imdb_dataset)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "846d83b1",
   "metadata": {},
   "source": [
    "# 2 Tokenization and Numericalization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "5ea762ba",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Tokenizer input max length: 512\n",
      "Tokenizer vocabulary size: 30522\n"
     ]
    }
   ],
   "source": [
    "from transformers import AutoTokenizer\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"distilbert-base-uncased\")\n",
    "print(\"Tokenizer input max length:\", tokenizer.model_max_length)\n",
    "print(\"Tokenizer vocabulary size:\", tokenizer.vocab_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "8432c15c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def tokenize_text(batch):\n",
    "    return tokenizer(batch[\"text\"], truncation=True, padding=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "0bb392cf",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "676e50abbf2a4bc684c308a176cf123b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "efc95a53bc294714b8b8a19eef857c3d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "520e97acf0ef45b9af84ab075b80f3a4",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "imdb_tokenized = imdb_dataset.map(tokenize_text, batched=True, batch_size=None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "6d4103c3",
   "metadata": {},
   "outputs": [],
   "source": [
    "del imdb_dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bfeb1553",
   "metadata": {},
   "source": [
    "# 3 Finetuning DistilBERT"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "9f2c474d",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_layer_norm.weight', 'vocab_projector.bias', 'vocab_transform.weight', 'vocab_transform.bias', 'vocab_projector.weight', 'vocab_layer_norm.bias']\n",
      "- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
      "Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.weight', 'classifier.bias', 'pre_classifier.bias', 'pre_classifier.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    }
   ],
   "source": [
    "from transformers import AutoModelForSequenceClassification\n",
    "\n",
    "model = AutoModelForSequenceClassification.from_pretrained(\n",
    "    \"distilbert-base-uncased\", num_labels=2)\n",
    "model.to(device);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "f3439723-46ef-4fdf-8040-6fc93e43f2ed",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_10843/4127383031.py:4: FutureWarning: load_metric is deprecated and will be removed in the next major version of datasets. Use 'evaluate.load' instead, from the new library 🤗 Evaluate: https://huggingface.co/docs/evaluate\n",
      "  metric = load_metric('accuracy')\n",
      "/home/raschka/miniforge3/envs/dl-fundamentals/lib/python3.9/site-packages/torch/utils/tensorboard/__init__.py:4: DeprecationWarning: distutils Version classes are deprecated. Use packaging.version instead.\n",
      "  if not hasattr(tensorboard, \"__version__\") or LooseVersion(\n",
      "The following columns in the training set don't have a corresponding argument in `DistilBertForSequenceClassification.forward` and have been ignored: index, text. If index, text are not expected by `DistilBertForSequenceClassification.forward`,  you can safely ignore this message.\n",
      "/home/raschka/miniforge3/envs/dl-fundamentals/lib/python3.9/site-packages/transformers/optimization.py:306: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
      "  warnings.warn(\n",
      "***** Running training *****\n",
      "  Num examples = 35000\n",
      "  Num Epochs = 3\n",
      "  Instantaneous batch size per device = 8\n",
      "  Total train batch size (w. parallel, distributed & accumulation) = 32\n",
      "  Gradient Accumulation steps = 1\n",
      "  Total optimization steps = 3282\n",
      "You're using a DistilBertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n",
      "/home/raschka/miniforge3/envs/dl-fundamentals/lib/python3.9/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
      "  warnings.warn('Was asked to gather along dimension 0, but all '\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='3282' max='3282' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [3282/3282 22:44, Epoch 3/3]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Epoch</th>\n",
       "      <th>Training Loss</th>\n",
       "      <th>Validation Loss</th>\n",
       "      <th>Accuracy</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>0.220200</td>\n",
       "      <td>0.195231</td>\n",
       "      <td>{'accuracy': 0.9272}</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>0.159100</td>\n",
       "      <td>0.198446</td>\n",
       "      <td>{'accuracy': 0.9272}</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>0.124900</td>\n",
       "      <td>0.202979</td>\n",
       "      <td>{'accuracy': 0.9312}</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Saving model checkpoint to distilbert-v1/checkpoint-500\n",
      "Configuration saved in distilbert-v1/checkpoint-500/config.json\n",
      "Model weights saved in distilbert-v1/checkpoint-500/pytorch_model.bin\n",
      "tokenizer config file saved in distilbert-v1/checkpoint-500/tokenizer_config.json\n",
      "Special tokens file saved in distilbert-v1/checkpoint-500/special_tokens_map.json\n",
      "/home/raschka/miniforge3/envs/dl-fundamentals/lib/python3.9/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
      "  warnings.warn('Was asked to gather along dimension 0, but all '\n",
      "Saving model checkpoint to distilbert-v1/checkpoint-1000\n",
      "Configuration saved in distilbert-v1/checkpoint-1000/config.json\n",
      "Model weights saved in distilbert-v1/checkpoint-1000/pytorch_model.bin\n",
      "tokenizer config file saved in distilbert-v1/checkpoint-1000/tokenizer_config.json\n",
      "Special tokens file saved in distilbert-v1/checkpoint-1000/special_tokens_map.json\n",
      "/home/raschka/miniforge3/envs/dl-fundamentals/lib/python3.9/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
      "  warnings.warn('Was asked to gather along dimension 0, but all '\n",
      "The following columns in the evaluation set don't have a corresponding argument in `DistilBertForSequenceClassification.forward` and have been ignored: index, text. If index, text are not expected by `DistilBertForSequenceClassification.forward`,  you can safely ignore this message.\n",
      "***** Running Evaluation *****\n",
      "  Num examples = 5000\n",
      "  Batch size = 32\n",
      "Trainer is attempting to log a value of \"{'accuracy': 0.9272}\" of type <class 'dict'> for key \"eval/accuracy\" as a scalar. This invocation of Tensorboard's writer.add_scalar() is incorrect so we dropped this attribute.\n",
      "/home/raschka/miniforge3/envs/dl-fundamentals/lib/python3.9/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
      "  warnings.warn('Was asked to gather along dimension 0, but all '\n",
      "Saving model checkpoint to distilbert-v1/checkpoint-1500\n",
      "Configuration saved in distilbert-v1/checkpoint-1500/config.json\n",
      "Model weights saved in distilbert-v1/checkpoint-1500/pytorch_model.bin\n",
      "tokenizer config file saved in distilbert-v1/checkpoint-1500/tokenizer_config.json\n",
      "Special tokens file saved in distilbert-v1/checkpoint-1500/special_tokens_map.json\n",
      "/home/raschka/miniforge3/envs/dl-fundamentals/lib/python3.9/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
      "  warnings.warn('Was asked to gather along dimension 0, but all '\n",
      "Saving model checkpoint to distilbert-v1/checkpoint-2000\n",
      "Configuration saved in distilbert-v1/checkpoint-2000/config.json\n",
      "Model weights saved in distilbert-v1/checkpoint-2000/pytorch_model.bin\n",
      "tokenizer config file saved in distilbert-v1/checkpoint-2000/tokenizer_config.json\n",
      "Special tokens file saved in distilbert-v1/checkpoint-2000/special_tokens_map.json\n",
      "/home/raschka/miniforge3/envs/dl-fundamentals/lib/python3.9/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
      "  warnings.warn('Was asked to gather along dimension 0, but all '\n",
      "The following columns in the evaluation set don't have a corresponding argument in `DistilBertForSequenceClassification.forward` and have been ignored: index, text. If index, text are not expected by `DistilBertForSequenceClassification.forward`,  you can safely ignore this message.\n",
      "***** Running Evaluation *****\n",
      "  Num examples = 5000\n",
      "  Batch size = 32\n",
      "Trainer is attempting to log a value of \"{'accuracy': 0.9272}\" of type <class 'dict'> for key \"eval/accuracy\" as a scalar. This invocation of Tensorboard's writer.add_scalar() is incorrect so we dropped this attribute.\n",
      "/home/raschka/miniforge3/envs/dl-fundamentals/lib/python3.9/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
      "  warnings.warn('Was asked to gather along dimension 0, but all '\n",
      "Saving model checkpoint to distilbert-v1/checkpoint-2500\n",
      "Configuration saved in distilbert-v1/checkpoint-2500/config.json\n",
      "Model weights saved in distilbert-v1/checkpoint-2500/pytorch_model.bin\n",
      "tokenizer config file saved in distilbert-v1/checkpoint-2500/tokenizer_config.json\n",
      "Special tokens file saved in distilbert-v1/checkpoint-2500/special_tokens_map.json\n",
      "/home/raschka/miniforge3/envs/dl-fundamentals/lib/python3.9/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
      "  warnings.warn('Was asked to gather along dimension 0, but all '\n",
      "Saving model checkpoint to distilbert-v1/checkpoint-3000\n",
      "Configuration saved in distilbert-v1/checkpoint-3000/config.json\n",
      "Model weights saved in distilbert-v1/checkpoint-3000/pytorch_model.bin\n",
      "tokenizer config file saved in distilbert-v1/checkpoint-3000/tokenizer_config.json\n",
      "Special tokens file saved in distilbert-v1/checkpoint-3000/special_tokens_map.json\n",
      "/home/raschka/miniforge3/envs/dl-fundamentals/lib/python3.9/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
      "  warnings.warn('Was asked to gather along dimension 0, but all '\n",
      "The following columns in the evaluation set don't have a corresponding argument in `DistilBertForSequenceClassification.forward` and have been ignored: index, text. If index, text are not expected by `DistilBertForSequenceClassification.forward`,  you can safely ignore this message.\n",
      "***** Running Evaluation *****\n",
      "  Num examples = 5000\n",
      "  Batch size = 32\n",
      "Trainer is attempting to log a value of \"{'accuracy': 0.9312}\" of type <class 'dict'> for key \"eval/accuracy\" as a scalar. This invocation of Tensorboard's writer.add_scalar() is incorrect so we dropped this attribute.\n",
      "\n",
      "\n",
      "Training completed. Do not forget to share your model on huggingface.co/models =)\n",
      "\n",
      "\n"
     ]
    }
   ],
   "source": [
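    "import numpy as np\n",
    "from datasets import load_metric\n",
    "from transformers import Trainer, TrainingArguments\n",
    "\n",
    "metric = load_metric('accuracy')\n",
    "\n",
    "def compute_metrics(eval_pred):\n",
    "    # eval_pred holds the model logits and the ground-truth labels\n",
    "    logits, labels = eval_pred\n",
    "    predictions = np.argmax(logits, axis=1)\n",
    "    # metric.compute already returns {\"accuracy\": <float>}; returning it\n",
    "    # directly avoids nesting a dict under the \"accuracy\" key, which the\n",
    "    # Trainer cannot log as a scalar\n",
    "    return metric.compute(predictions=predictions, references=labels)\n",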
    "\n",
    "trainer_args = TrainingArguments(output_dir=\"distilbert-v1\",\n",
    "                                 num_train_epochs=3,\n",
    "                                 evaluation_strategy=\"epoch\",\n",
    "                                 learning_rate=1e-5)\n",
    "\n",
    "trainer = Trainer(model=model,\n",
    "                  args=trainer_args,\n",
    "                  compute_metrics=compute_metrics,\n",
    "                  train_dataset=imdb_tokenized[\"train\"],\n",
    "                  eval_dataset=imdb_tokenized[\"validation\"],\n",
    "                  tokenizer=tokenizer)\n",
    "\n",
    "trainer.train();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "86041e3a-0ebd-45b9-8fce-f93df5e64e7d",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "The following columns in the test set don't have a corresponding argument in `DistilBertForSequenceClassification.forward` and have been ignored: index, text. If index, text are not expected by `DistilBertForSequenceClassification.forward`,  you can safely ignore this message.\n",
      "***** Running Prediction *****\n",
      "  Num examples = 35000\n",
      "  Batch size = 32\n",
      "/home/raschka/miniforge3/envs/dl-fundamentals/lib/python3.9/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
      "  warnings.warn('Was asked to gather along dimension 0, but all '\n"
     ]
    },
    {
     "data": {
      "text/html": [],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "{'test_loss': 0.09555176645517349,\n",
       " 'test_accuracy': {'accuracy': 0.9704571428571429},\n",
       " 'test_runtime': 186.0604,\n",
       " 'test_samples_per_second': 188.111,\n",
       " 'test_steps_per_second': 5.88}"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
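    "# Evaluate on the training split; predict() prefixes metric names\n",
    "# with \"test_\" by default, regardless of which split is passed in\n",
    "outputs = trainer.predict(imdb_tokenized[\"train\"])\n",
    "outputs.metrics"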
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "11b8e5d5-6e93-4cdb-8d76-43ff5216a1cc",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "The following columns in the test set don't have a corresponding argument in `DistilBertForSequenceClassification.forward` and have been ignored: index, text. If index, text are not expected by `DistilBertForSequenceClassification.forward`,  you can safely ignore this message.\n",
      "***** Running Prediction *****\n",
      "  Num examples = 5000\n",
      "  Batch size = 32\n",
      "/home/raschka/miniforge3/envs/dl-fundamentals/lib/python3.9/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
      "  warnings.warn('Was asked to gather along dimension 0, but all '\n"
     ]
    },
    {
     "data": {
      "text/html": [],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "{'test_loss': 0.2029794603586197,\n",
       " 'test_accuracy': {'accuracy': 0.9312},\n",
       " 'test_runtime': 26.8347,\n",
       " 'test_samples_per_second': 186.326,\n",
       " 'test_steps_per_second': 5.851}"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Evaluate on the validation split\n",
    "outputs = trainer.predict(imdb_tokenized[\"validation\"])\n",
    "outputs.metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "7b72238b-976f-4e63-ba2f-6a78f50f3530",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "The following columns in the test set don't have a corresponding argument in `DistilBertForSequenceClassification.forward` and have been ignored: index, text. If index, text are not expected by `DistilBertForSequenceClassification.forward`,  you can safely ignore this message.\n",
      "***** Running Prediction *****\n",
      "  Num examples = 10000\n",
      "  Batch size = 32\n",
      "/home/raschka/miniforge3/envs/dl-fundamentals/lib/python3.9/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
      "  warnings.warn('Was asked to gather along dimension 0, but all '\n"
     ]
    },
    {
     "data": {
      "text/html": [],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "{'test_loss': 0.21524782478809357,\n",
       " 'test_accuracy': {'accuracy': 0.926},\n",
       " 'test_runtime': 53.8797,\n",
       " 'test_samples_per_second': 185.599,\n",
       " 'test_steps_per_second': 5.809}"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Evaluate on the held-out test split\n",
    "outputs = trainer.predict(imdb_tokenized[\"test\"])\n",
    "outputs.metrics"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
